"""Assist satellite entity for VoIP integration."""

from __future__ import annotations

import asyncio
from enum import IntFlag
from functools import partial
import io
import logging
from pathlib import Path
import socket
import time
from typing import TYPE_CHECKING, Any, Final
import wave

from voip_utils import SIP_PORT, RtpDatagramProtocol
from voip_utils.sip import SipDatagramProtocol, SipEndpoint, get_sip_endpoint

from homeassistant.components import tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
    AssistSatelliteAnnouncement,
    AssistSatelliteConfiguration,
    AssistSatelliteEntity,
    AssistSatelliteEntityDescription,
    AssistSatelliteEntityFeature,
)
from homeassistant.components.network import async_get_source_ip
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from .const import (
    CHANNELS,
    CONF_SIP_PORT,
    CONF_SIP_USER,
    DOMAIN,
    RATE,
    RTP_AUDIO_SETTINGS,
    WIDTH,
)
from .devices import VoIPDevice
from .entity import VoIPEntity

if TYPE_CHECKING:
    from . import DomainData

_LOGGER = logging.getLogger(__name__)

_PIPELINE_TIMEOUT_SEC: Final = 30
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5
_ANNOUNCEMENT_RING_TIMEOUT: Final = 30


class Tones(IntFlag):
    """Feedback tones for specific events."""

    LISTENING = 1
    PROCESSING = 2
    ERROR = 4


_TONE_FILENAMES: dict[Tones, str] = {
    Tones.LISTENING: "tone.pcm",
    Tones.PROCESSING: "processing.pcm",
    Tones.ERROR: "error.pcm",
}


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up VoIP Assist satellite entity."""
    domain_data: DomainData = hass.data[DOMAIN]

    @callback
    def async_add_device(device: VoIPDevice) -> None:
        """Add device."""
        async_add_entities([VoipAssistSatellite(hass, device, config_entry)])

    domain_data.devices.async_add_new_device_listener(async_add_device)

    entities: list[VoIPEntity] = [
        VoipAssistSatellite(hass, device, config_entry)
        for device in domain_data.devices
    ]

    async_add_entities(entities)


class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol):
    """Assist satellite for VoIP devices."""

    entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
    _attr_translation_key = "assist_satellite"
    _attr_name = None
    _attr_supported_features = (
        AssistSatelliteEntityFeature.ANNOUNCE
        | AssistSatelliteEntityFeature.START_CONVERSATION
    )

    def __init__(
        self,
        hass: HomeAssistant,
        voip_device: VoIPDevice,
        config_entry: ConfigEntry,
        tones=Tones.LISTENING | Tones.PROCESSING | Tones.ERROR,
    ) -> None:
        """Initialize an Assist satellite."""
        VoIPEntity.__init__(self, voip_device)
        AssistSatelliteEntity.__init__(self)
        RtpDatagramProtocol.__init__(self)

        self.config_entry = config_entry

        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._audio_chunk_timeout: float = 2.0
        self._run_pipeline_task: asyncio.Task | None = None
        self._pipeline_had_error: bool = False
        self._tts_done = asyncio.Event()
        self._tts_extra_timeout: float = 1.0
        self._tone_bytes: dict[Tones, bytes] = {}
        self._tones = tones
        self._processing_tone_done = asyncio.Event()

        self._announcement: AssistSatelliteAnnouncement | None = None
        self._announcement_future: asyncio.Future[Any] = asyncio.Future()
        self._announcment_start_time: float = 0.0
        self._check_announcement_ended_task: asyncio.Task | None = None
        self._last_chunk_time: float | None = None
        self._rtp_port: int | None = None
        self._run_pipeline_after_announce: bool = False

    @property
    def pipeline_entity_id(self) -> str | None:
        """Return the entity ID of the pipeline to use for the next conversation."""
        return self.voip_device.get_pipeline_entity_id(self.hass)

    @property
    def vad_sensitivity_entity_id(self) -> str | None:
        """Return the entity ID of the VAD sensitivity to use for the next conversation."""
        return self.voip_device.get_vad_sensitivity_entity_id(self.hass)

    @property
    def tts_options(self) -> dict[str, Any] | None:
        """Options passed for text-to-speech."""
        return {
            tts.ATTR_PREFERRED_FORMAT: "wav",
            tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        }

    async def async_added_to_hass(self) -> None:
        """Run when entity about to be added to hass."""
        await super().async_added_to_hass()
        self.voip_device.protocol = self

    async def async_will_remove_from_hass(self) -> None:
        """Run when entity will be removed from hass."""
        await super().async_will_remove_from_hass()
        assert self.voip_device.protocol == self
        self.voip_device.protocol = None

    @callback
    def async_get_configuration(
        self,
    ) -> AssistSatelliteConfiguration:
        """Get the current satellite configuration."""
        raise NotImplementedError

    async def async_set_configuration(
        self, config: AssistSatelliteConfiguration
    ) -> None:
        """Set the current satellite configuration."""
        raise NotImplementedError

    async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
        """Announce media on the satellite.

        Plays announcement in a loop, blocking until the caller hangs up.
        """
        await self._do_announce(announcement, run_pipeline_after=False)

    async def _do_announce(
        self, announcement: AssistSatelliteAnnouncement, run_pipeline_after: bool
    ) -> None:
        """Announce media on the satellite.

        Optionally run a voice pipeline after the announcement has finished.
        """
        self._announcement_future = asyncio.Future()
        self._run_pipeline_after_announce = run_pipeline_after

        if self._rtp_port is None:
            # Choose random port for RTP
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.bind(("", 0))
            _rtp_ip, self._rtp_port = sock.getsockname()
            sock.close()

        # HA SIP server
        source_ip = await async_get_source_ip(self.hass)
        sip_port = self.config_entry.options.get(CONF_SIP_PORT, SIP_PORT)
        sip_user = self.config_entry.options.get(CONF_SIP_USER)
        source_endpoint = get_sip_endpoint(
            host=source_ip, port=sip_port, username=sip_user
        )

        try:
            # VoIP ID is SIP header
            destination_endpoint = SipEndpoint(self.voip_device.voip_id)
        except ValueError:
            # VoIP ID is IP address
            destination_endpoint = get_sip_endpoint(
                host=self.voip_device.voip_id, port=SIP_PORT
            )

        # Reset state so we can time out if needed
        self._last_chunk_time = None
        self._announcment_start_time = time.monotonic()
        self._announcement = announcement

        # Make the call
        sip_protocol: SipDatagramProtocol = self.hass.data[DOMAIN].protocol
        call_info = sip_protocol.outgoing_call(
            source=source_endpoint,
            destination=destination_endpoint,
            rtp_port=self._rtp_port,
        )

        # Check if caller hung up or didn't pick up
        self._check_announcement_ended_task = (
            self.config_entry.async_create_background_task(
                self.hass,
                self._check_announcement_ended(),
                "voip_announcement_ended",
            )
        )

        try:
            await self._announcement_future
        except TimeoutError:
            # Stop ringing
            sip_protocol.cancel_call(call_info)
            raise

    async def _check_announcement_ended(self) -> None:
        """Continuously checks if an audio chunk was received within a time limit.

        If not, the caller is presumed to have hung up and the announcement is ended.
        """
        while self._announcement is not None:
            current_time = time.monotonic()
            if (self._last_chunk_time is None) and (
                (current_time - self._announcment_start_time)
                > _ANNOUNCEMENT_RING_TIMEOUT
            ):
                # Ring timeout
                self._announcement = None
                self._check_announcement_ended_task = None
                self._announcement_future.set_exception(
                    TimeoutError("User did not pick up in time")
                )
                _LOGGER.debug("Timed out waiting for the user to pick up the phone")
                break

            if (self._last_chunk_time is not None) and (
                (current_time - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC
            ):
                # Caller hung up
                self._announcement = None
                self._announcement_future.set_result(None)
                self._check_announcement_ended_task = None
                _LOGGER.debug("Announcement ended")
                break

            await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)

    async def async_start_conversation(
        self, start_announcement: AssistSatelliteAnnouncement
    ) -> None:
        """Start a conversation from the satellite."""
        await self._do_announce(start_announcement, run_pipeline_after=True)

    # -------------------------------------------------------------------------
    # VoIP
    # -------------------------------------------------------------------------

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Handle raw audio chunk."""
        self._last_chunk_time = time.monotonic()

        if self._announcement is None:
            # Pipeline with STT
            if self._run_pipeline_task is None:
                # Run pipeline until voice command finishes, then start over
                self._clear_audio_queue()
                self._tts_done.clear()
                self._run_pipeline_task = (
                    self.config_entry.async_create_background_task(
                        self.hass,
                        self._run_pipeline(),
                        "voip_pipeline_run",
                    )
                )

            self._audio_queue.put_nowait(audio_bytes)
        elif self._run_pipeline_task is None:
            # Announcement only
            # Play announcement (will repeat)
            self._run_pipeline_task = self.config_entry.async_create_background_task(
                self.hass,
                self._play_announcement(self._announcement),
                "voip_play_announcement",
            )

    async def _run_pipeline(self) -> None:
        """Run a pipeline with STT input and TTS output."""
        _LOGGER.debug("Starting pipeline")

        self.async_set_context(Context(user_id=self.config_entry.data["user"]))
        self.voip_device.set_is_active(True)

        async def stt_stream():
            while True:
                async with asyncio.timeout(self._audio_chunk_timeout):
                    chunk = await self._audio_queue.get()
                    if not chunk:
                        break

                    yield chunk

        # Play listening tone at the start of each cycle
        await self._play_tone(Tones.LISTENING, silence_before=0.2)

        try:
            await self.async_accept_pipeline_from_satellite(
                audio_stream=stt_stream(),
            )

            if self._pipeline_had_error:
                self._pipeline_had_error = False
                await self._play_tone(Tones.ERROR)
            else:
                # Block until TTS is done speaking.
                #
                # This is set in _send_tts and has a timeout that's based on the
                # length of the TTS audio.
                await self._tts_done.wait()
        except TimeoutError:
            self.disconnect()  # caller hung up
        finally:
            # Stop audio stream
            await self._audio_queue.put(None)

            self.voip_device.set_is_active(False)
            self._run_pipeline_task = None
            _LOGGER.debug("Pipeline finished")

    async def _play_announcement(
        self, announcement: AssistSatelliteAnnouncement
    ) -> None:
        """Play an announcement once."""
        _LOGGER.debug("Playing announcement")

        try:
            await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
            await self._send_tts(announcement.original_media_id, wait_for_tone=False)

            if not self._run_pipeline_after_announce:
                # Delay before looping announcement
                await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
        except Exception:
            _LOGGER.exception("Unexpected error while playing announcement")
            raise
        finally:
            self._run_pipeline_task = None
            _LOGGER.debug("Announcement finished")

            if self._run_pipeline_after_announce:
                # Clear announcement to allow pipeline to run
                self._announcement = None
                self._announcement_future.set_result(None)

    def _clear_audio_queue(self) -> None:
        """Ensure audio queue is empty."""
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

    def on_pipeline_event(self, event: PipelineEvent) -> None:
        """Set state based on pipeline stage."""
        if event.type == PipelineEventType.STT_END:
            if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
                self._processing_tone_done.clear()
                self.config_entry.async_create_background_task(
                    self.hass, self._play_tone(Tones.PROCESSING), "voip_process_tone"
                )
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            if event.data and (tts_output := event.data["tts_output"]):
                media_id = tts_output["media_id"]
                self.config_entry.async_create_background_task(
                    self.hass,
                    self._send_tts(media_id),
                    "voip_pipeline_tts",
                )
            else:
                # Empty TTS response
                self._tts_done.set()
        elif event.type == PipelineEventType.ERROR:
            # Play error tone instead of wait for TTS when pipeline is finished.
            self._pipeline_had_error = True
            _LOGGER.warning(event)

    async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
        """Send TTS audio to caller via RTP."""
        try:
            if self.transport is None:
                return  # not connected

            extension, data = await tts.async_get_media_source_audio(
                self.hass,
                media_id,
            )

            if extension != "wav":
                raise ValueError(f"Only WAV audio can be streamed, got {extension}")

            if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
                # Don't overlap TTS and processing beep
                _LOGGER.debug("Waiting for processing tone")
                await self._processing_tone_done.wait()

            with io.BytesIO(data) as wav_io:
                with wave.open(wav_io, "rb") as wav_file:
                    sample_rate = wav_file.getframerate()
                    sample_width = wav_file.getsampwidth()
                    sample_channels = wav_file.getnchannels()

                    if (
                        (sample_rate != RATE)
                        or (sample_width != WIDTH)
                        or (sample_channels != CHANNELS)
                    ):
                        raise ValueError(
                            f"Expected rate/width/channels as {RATE}/{WIDTH}/{CHANNELS},"
                            f" got {sample_rate}/{sample_width}/{sample_channels}"
                        )

                audio_bytes = wav_file.readframes(wav_file.getnframes())

            _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

            # Time out 1 second after TTS audio should be finished
            tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
            tts_seconds = tts_samples / RATE

            async with asyncio.timeout(tts_seconds + self._tts_extra_timeout):
                # TTS audio is 16Khz 16-bit mono
                await self._async_send_audio(audio_bytes)
        except TimeoutError:
            _LOGGER.warning("TTS timeout")
            raise
        finally:
            # Update satellite state
            self.tts_response_finished()

            # Signal pipeline to restart
            self._tts_done.set()

    async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
        """Send audio in executor."""
        await self.hass.async_add_executor_job(
            partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
        )

    async def _play_tone(self, tone: Tones, silence_before: float = 0.0) -> None:
        """Play a tone as feedback to the user if it's enabled."""
        if (self._tones & tone) != tone:
            return  # not enabled

        if tone not in self._tone_bytes:
            # Do I/O in executor
            self._tone_bytes[tone] = await self.hass.async_add_executor_job(
                self._load_pcm,
                _TONE_FILENAMES[tone],
            )

        await self._async_send_audio(
            self._tone_bytes[tone],
            silence_before=silence_before,
        )

        if tone == Tones.PROCESSING:
            self._processing_tone_done.set()

    def _load_pcm(self, file_name: str) -> bytes:
        """Load raw audio (16Khz, 16-bit mono)."""
        return (Path(__file__).parent / file_name).read_bytes()
